import torch
# import cv2 as cv
from model.yolov5 import YoloV5
from utils.dataset import ZhnDatasetClass, ZhnDatasetDetect
from utils.loss import total_loss
from utils.utils import freeze_param, param_disturb

# Training configuration.
# train_backbone=True trains the classification head via ZhnDatasetClass;
# False trains the detection path via ZhnDatasetDetect.
train_backbone = True
# Request the GPU; the device selection below falls back to CPU if CUDA is absent.
use_cuda = True
train_dataset = ZhnDatasetClass() if train_backbone else ZhnDatasetDetect()
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
print('Dataset loading complete.')
# Only pick cuda:0 when it is actually available, otherwise torch raises on .to().
device = torch.device('cuda:0' if use_cuda and torch.cuda.is_available() else 'cpu')
model = YoloV5(classify=train_backbone)
# Load the checkpoint onto the CPU first so a file saved from a GPU run can be
# restored on any machine; the model is moved to `device` immediately after.
model.load_state_dict(torch.load('yolov5.pth', map_location='cpu'))
model.to(device)
# NOTE(review): presumably freezes every parameter except those matching
# `exclude` ('classify' head or 'bbr' head) -- confirm in utils.utils.
freeze_param(model, exclude='classify' if train_backbone else 'bbr')
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0)
# Cross-entropy for backbone classification; project-defined total_loss for detection.
loss_function = torch.nn.CrossEntropyLoss() if train_backbone else total_loss
print('Network loading complete.')
# Running sum of per-step losses (a plain Python float) and the step counter
# for the current 100-step reporting window.
average_loss = 0
step = 0
# Train indefinitely; each pass over the dataloader is one epoch.
while True:
    for images, labels in train_dataloader:
        images, labels = images.to(device), labels.to(device)
        predict = model(images)
        # Fail loudly on NaN outputs. An `assert` would be stripped under -O.
        nan_cnt = int(torch.isnan(predict).sum())
        if nan_cnt:
            raise RuntimeError(f'model output contains {nan_cnt} NaN values')
        train_loss = loss_function(predict, labels)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        # Convert to a Python float. Summing the tensor itself would chain
        # autograd nodes across steps and keep the value on the device.
        loss_value = train_loss.item()
        average_loss += loss_value
        print('\rtrain loss={:.6f}'.format(loss_value), end='')
        step += 1
        if step >= 100:
            average_loss /= step
            print(f'\nAverage loss is {average_loss}.')
            average_loss = 0
            step = 0
            # torch.save(model.state_dict(), 'yolov5.pth')
            # Periodically perturb the weights. The noise scale is tied to the
            # learning rate; its exact semantics live in utils.utils.param_disturb.
            param_disturb(model, var=lr/50, device=device)
    print('\nepoch finished.')
